import collections
import logging
import os
import sys
import time

try:
    from multiprocessing import shared_memory
except ImportError:
    logging.warning("shared_memory requires python version >= 3.8")

import numpy as np
import torch

import atorch
from atorch.common.log_utils import default_logger as logger
from atorch.data.data_utils import fast_batch_copy, get_sample_batch
from atorch.data.preloader import data_to_device


class _TensorInfo:
    def __init__(self, dtype, shape):
        self.dtype = dtype
        self.shape = shape


class _TypeDataInfo:
    def __init__(self, element_size):
        self.element_size = element_size
        self.element_num = 0
        self.split_info = []

    def add_item(self, numel):
        self.element_num += numel
        self.split_info.append(numel)


def _write_batch_to_shm(data, tensors, cur_index=None):
    # write batch data into tensors (shared memory)
    if cur_index is None:
        cur_index = {}
        for dtype in tensors:
            cur_index[dtype] = 0

    if isinstance(data, collections.abc.Sequence):
        for v in data:
            _write_batch_to_shm(v, tensors, cur_index)
    elif isinstance(data, collections.abc.Mapping):
        for key in data:
            _write_batch_to_shm(data[key], tensors, cur_index)
    else:
        assert isinstance(data, torch.Tensor), f"should be torch.Tensor, but found {type(data)}"
        index = cur_index[data.dtype]
        cur_index[data.dtype] += 1
        tensors[data.dtype][index][:] = data.reshape(tensors[data.dtype][index].shape)[:]


def _get_batch_with_tensor(data, tensors, cur_index=None):
    """
    data: a data structure whose leaf nodes are _TensorInfo
    tensors: dict[dtype: tensor_list]
    """
    if cur_index is None:
        cur_index = {}
        for dtype in tensors:
            cur_index[dtype] = 0

    if isinstance(data, collections.abc.Sequence):
        seq_data = []
        for v in data:
            seq_data.append(_get_batch_with_tensor(v, tensors, cur_index))
        if isinstance(data, tuple):
            seq_data = tuple(seq_data)
        return seq_data
    elif isinstance(data, collections.abc.Mapping):
        dict_data = {}
        for key in data:
            dict_data[key] = _get_batch_with_tensor(data[key], tensors, cur_index)
        return dict_data
    else:
        assert isinstance(data, _TensorInfo)
        index = cur_index[data.dtype]
        cur_index[data.dtype] += 1
        tensor = tensors[data.dtype][index].reshape(data.shape)
        return tensor


def _get_batch_info_helper(data, size_info):
    if isinstance(data, collections.abc.Sequence):
        seq_data = []
        for v in data:
            seq_data.append(_get_batch_info_helper(v, size_info))
        if isinstance(data, tuple):
            seq_data = tuple(seq_data)
        return seq_data
    elif isinstance(data, collections.abc.Mapping):
        dict_data = {}
        for key in data:
            dict_data[key] = _get_batch_info_helper(data[key], size_info)
        return dict_data
    elif isinstance(data, torch.Tensor):
        dtype = data.dtype
        numel = torch.numel(data)
        shape = data.shape
        if dtype not in size_info:
            size_info[dtype] = _TypeDataInfo(data.element_size())
        size_info[dtype].add_item(numel)
        data_info = _TensorInfo(dtype, shape)
        return data_info
    else:
        raise ValueError(f"batch data supports for tensor data only, found data type {type(data)}")


def get_batch_info(sample_batch):
    """Return [data_info, size_info] for sample_batch.
    data_info: the same nested structure as sample_batch with every tensor replaced by a _TensorInfo.
    size_info: dict(dtype -> _TypeDataInfo) recording the per-dtype total element count and split sizes.
    """
    size_info = {}
    data_info = _get_batch_info_helper(sample_batch, size_info)
    return [data_info, size_info]
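
# Illustrative example (hypothetical batch) of what get_batch_info produces:
# for sample_batch = {"input_ids": int64 tensor of shape (8, 128),
#                     "labels": int64 tensor of shape (8, 128),
#                     "pixels": float32 tensor of shape (8, 3, 32, 32)},
# data_info keeps the same nested structure with every tensor replaced by a _TensorInfo(dtype, shape),
# and size_info maps torch.int64 to a _TypeDataInfo with element_num=2048, split_info=[1024, 1024],
# and torch.float32 to a _TypeDataInfo with element_num=24576, split_info=[24576].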


def _try_get_shm(name, timeout=300):
    for _ in range(timeout):
        try:
            shm = shared_memory.SharedMemory(name=name)
            return shm
        except Exception:
            time.sleep(1)
    return None


def _try_shm_close(shm, unlink=False):
    try:
        shm.close()
        if unlink:
            shm.unlink()
    except Exception:
        pass


class ShmDataContext:
    def __init__(
        self,
        is_master,
        sample_batch,
        rank,
        num_read_per_batch=None,
        num_batch_per_step=None,
        num_coworker_per_node=None,
        shm_name_prefix=None,
        shm_data_size=100,
        need_sync_write=True,
        io_timeout=30,
        initialize_timeout=300,
    ):
        """
        Used in two scenarios:
        O1. In one node, some gpu processes are used as coworkers and the other gpu processes as workers.
            Coworker processes generate batch data and store it in shared memory, and worker processes
            read batch data from shared memory for training. Only worker processes participate in training.
            Coworker processes use small ranks, and worker processes use large ranks.
            For example, in an 8-gpu node, 3 gpus serve as coworkers and 5 gpus as workers. Each worker
            reads batch data from shared memory generated by the coworkers.
            Need to set num_read_per_batch=1,
                        num_batch_per_step=worker_num_per_node,
                        num_coworker_per_node=coworker_num_per_node
        O2. In one node, some gpus are doing model parallel and require the same batch for training.
            rank0 reads/processes data, generates batches, and stores them in shared memory. Non-rank0
            processes read batches directly from shared memory. All gpus participate in training.
            Need to set num_read_per_batch=non_rank0_worker_num_per_node,
                        num_batch_per_step=1

        ==Input arguments==
        is_master: O1: True for gpu coworker, False for gpu worker.
                   O2: True for mp rank0, False for non-rank0.
        rank: for O1, local_rank in its process group.
              for O2, parallel_rank("model").
        sample_batch: a sample batch used to determine batch data dtype, shape, and size.
        num_read_per_batch: number of processes that read the same batch.
        num_batch_per_step: number of batches needed for one step.
        shm_name_prefix: shared memory name prefix. If None, use "atorch_shm_".
        num_coworker_per_node: for O1, the number of coworker processes per node.
        shm_data_size: the number of batches that can be stored in each shm group.
        need_sync_write: whether the master checks storage availability before writing data.
                         Not needed for O2 if shm_data_size > num_workers * prefetch_factor.
        io_timeout: shared memory read/write timeout in seconds, default 30.
        initialize_timeout: shm initialization timeout in seconds, default 300.
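
        Example (O2 sketch with hypothetical values; actual ranks come from the launcher):
            # 4 gpus doing model parallel on one node: rank 0 writes batches,
            # ranks 1-3 read the same batch from shared memory.
            ctx = ShmDataContext(
                is_master=(rank == 0),
                sample_batch=sample_batch,
                rank=rank,
                num_read_per_batch=3,
                num_batch_per_step=1,
            )
            if rank == 0:
                ctx.add_batch([batch])  # write one batch per step
            else:
                batch = ctx.get_data()  # blocks until a batch is available
            ctx.tear_down()
        For the O1 coworker case, see create_coworker_shm_context below.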
        """
        self.is_master = is_master
        self.rank = rank
        self.num_read_per_batch = num_read_per_batch
        self.num_batch_per_step = num_batch_per_step
        self.shm_name_prefix = shm_name_prefix if shm_name_prefix is not None else "atorch_shm_"
        self.num_coworker_per_node = num_coworker_per_node
        self.shm_data_size = shm_data_size
        self.need_sync_write = need_sync_write
        self.io_timeout = io_timeout
        self.initialize_timeout = initialize_timeout
        self.coworker_case = self.num_coworker_per_node is not None  # if is O1
        if is_master:
            self.shm_group_indices = [rank]
        elif self.coworker_case:
            # worker in O1
            self.shm_group_indices = list(range(self.num_coworker_per_node))
        else:
            # non-rank0 in O2
            self.shm_group_indices = [0]
        self.shms = {}
        self.batch_info = get_batch_info(sample_batch)  # [data_info with _TensorInfo leaves, dict(dtype -> _TypeDataInfo)]
        self.shm_created = False
        # Hack: get rid of user warning from resource tracker as shared memory is only unlinked in master.
        os.environ["PYTHONWARNINGS"] = "ignore::UserWarning"
        self._create_shms()

    def _create_shms(self):
        for idx in self.shm_group_indices:
            self.create_shm_group(idx)
        self.shm_created = True

    def create_shm_group(self, index):
        """A shm group consists of data, state.
        data: dict(dtype: [shm, tensor_list])
            each dtype has a shared mem (size = batch_total_size[dtype] * shm_data_size * num_batch_per_step)
            and a corresponding splitted 1D tensor list.
        state: (shm, np_data).
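        Illustrative sizing (hypothetical numbers): if a batch has 24576 float32 elements in total,
        with shm_data_size=100 and num_batch_per_step=1 the float32 shm holds 100 * 24576 elements
        (4 bytes each), split into a list of 100 1-D tensors of 24576 elements, one per batch slot.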
        """
        shm_name = self.shm_name_prefix + str(index)
        type_data_info = self.batch_info[1]
        data_shm = {}
        for dtype in type_data_info:
            name = shm_name + f"_data_{dtype}"
            ele_count = type_data_info[dtype].element_num * self.shm_data_size * self.num_batch_per_step
            if self.is_master:
                # create shm
                size = type_data_info[dtype].element_size * ele_count
                shm = shared_memory.SharedMemory(create=True, size=size, name=name)
            else:
                # get existing shm_group
                shm = _try_get_shm(name, timeout=self.initialize_timeout)
                if shm is None:
                    raise TimeoutError(f"cannot connect to shared memory {name}")
            # create tensor for shared mem
            tensor_data = torch.frombuffer(shm.buf, dtype=dtype, count=ele_count)
            split_tensor_data = tensor_data.split(type_data_info[dtype].element_num)
            data_shm[dtype] = (shm, split_tensor_data)  # shared_mem, tensor_list
        # then for state_shm
        name = shm_name + "_state_shm"
        size = (2 + self.num_read_per_batch * 2) * self.num_batch_per_step
        """state layout:
        FINISH: if no more data.
        WRITE_COUNT: how many batches already written.
        READ_COUNT: how many batches already read.
        STOP: if worker request stop.
        For O1: FINISH * num_batch_per_step | WRITE_COUNT * num_batch_per_step | \
                READ_COUNT * num_batch_per_step | STOP * num_batch_per_step
        For O2: FINISH | WRITE_COUNT | READ_COUNT * num_read_per_batch | STOP * num_read_per_batch
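        Example (illustrative): with O1, num_batch_per_step=4 and num_read_per_batch=1,
        size = (2 + 1 * 2) * 4 = 16 int64 slots:
        indices 0-3 FINISH, 4-7 WRITE_COUNT, 8-11 READ_COUNT, 12-15 STOP.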
        """
        if self.is_master:
            byte_per_ele = np.dtype(np.int64).itemsize
            shm = shared_memory.SharedMemory(create=True, size=size * byte_per_ele, name=name)
            np_data = np.ndarray([size], dtype=np.int64, buffer=shm.buf)
            for i in range(size):
                np_data[i] = 0
        else:
            shm = _try_get_shm(name, timeout=self.initialize_timeout)
            if shm is None:
                raise TimeoutError(f"cannot connect to shared memory {name}")
            np_data = np.ndarray([size], dtype=np.int64, buffer=shm.buf)
        state_shm = (shm, np_data)

        shm_group = {"data": data_shm, "state": state_shm}
        self.shms[index] = shm_group

    def get_write_count(self, idx=0):
        # get write count list from shms[idx]
        if self.is_master:
            shm_state = self.shms[self.shm_group_indices[0]]["state"][1]
        else:
            shm_state = self.shms[idx]["state"][1]
        write_count = []
        for i in range(self.num_batch_per_step):
            write_count.append(shm_state[self.num_batch_per_step + i])
        return write_count

    def get_read_count(self, idx=0):
        # get read count list from shms[idx]
        if self.is_master:
            shm_state = self.shms[self.shm_group_indices[0]]["state"][1]
        else:
            shm_state = self.shms[idx]["state"][1]
        read_count = []
        for i in range(self.num_batch_per_step * self.num_read_per_batch):
            read_count.append(shm_state[2 * self.num_batch_per_step + i])
        return read_count

    def get_stop_status(self):
        # Return list(bool) of the stop status. Master only.
        if self.is_master:
            shm_state = self.shms[self.shm_group_indices[0]]["state"][1]
            stop_index = (2 + self.num_read_per_batch) * self.num_batch_per_step
            stop_size = self.num_batch_per_step * self.num_read_per_batch
            stop_status = [shm_state[stop_index + i] > 0 for i in range(stop_size)]
            return stop_status
        else:
            return []

    def set_stop_status(self):
        # non-master only. Set stop status.
        if not self.is_master:
            _, _, _, stop_index = self._get_state_index()
            for idx in self.shm_group_indices:
                shm_state = self.shms[idx]["state"][1]
                shm_state[stop_index] = 1

    def add_batch(self, batches):
        # Master only. If batches is None, it marks the end of data.
        if not self.is_master:
            return
        state_shm = self.shms[self.shm_group_indices[0]]["state"][1]
        if batches is None:
            # no more data, set FINISH
            for idx in range(self.num_batch_per_step):
                state_shm[idx] = 1  # 1 for end of data
        else:
            num = len(batches)
            # the number of batches should be a multiple of self.num_batch_per_step.
            assert (
                num % self.num_batch_per_step == 0
            ), f"{num} batches, not multiple of num_batch_per_step({self.num_batch_per_step})"
            # write batches to shm
            step_num = num // self.num_batch_per_step
            data_shm = self.shms[self.shm_group_indices[0]]["data"]
            for step in range(step_num):
                for idx in range(self.num_batch_per_step):
                    batch = batches[step * self.num_batch_per_step + idx]
                    request_stop = False
                    # wait until storage is ready if need_sync_write
                    if self.need_sync_write:
                        request_stop = self._sync_on_write(state_shm, idx)
                    if request_stop:
                        break
                    # write batch data to shm
                    self._write_batch(data_shm, state_shm, batch, idx)
                    # update state
                    self._update_state_on_write(state_shm, idx)

    def _sync_on_write(self, state_shm, idx):
        # Return True if a worker has requested stop, in which case the write is skipped.
        current_write_count = state_shm[self.num_batch_per_step + idx]
        if current_write_count < self.shm_data_size:
            # storage not full yet, safe to write
            return False
        min_read_count = current_write_count - self.shm_data_size + 1
        read_count_index = self.num_batch_per_step * 2 + idx * self.num_read_per_batch
        read_count_size = self.num_read_per_batch
        start_time = time.time()
        while True:
            ready = True
            for i in range(read_count_size):
                if state_shm[read_count_index + i] < min_read_count:
                    ready = False
                    break
            if ready:
                break
            stop_status = self.get_stop_status()
            if any(stop_status):
                return True
            if time.time() - start_time > self.io_timeout:
                raise TimeoutError(f"Timeout on write batch {current_write_count}")
            time.sleep(0.001)
        return False

    def _write_batch(self, data_shm, state_shm, batch, idx):
        current_write_index = state_shm[self.num_batch_per_step + idx] % self.shm_data_size
        if self.coworker_case:
            current_write_index = current_write_index * self.num_batch_per_step + idx
        tensors = {}
        for dtype in data_shm:
            tensor = data_shm[dtype][1][current_write_index]  # get corresponding tensor from tensor_list
            split_info = self.batch_info[1][dtype].split_info
            splitted_tensor = tensor.split(split_info)
            tensors[dtype] = splitted_tensor
        _write_batch_to_shm(batch, tensors)

    def _update_state_on_write(self, state_shm, idx):
        state_shm[self.num_batch_per_step + idx] += 1

    def get_data(self, index=0, device=None):
        # return a batch, or None if no more data. non-master only.
        if self.is_master:
            return
        shm_data = self.shms[index]["data"]
        shm_state = self.shms[index]["state"][1]
        finish_index, write_count_index, read_count_index, _ = self._get_state_index()
        data_index = self._get_data_index(shm_state, finish_index, write_count_index, read_count_index)
        if data_index is None:
            return None
        batch = self._read_batch(shm_data, data_index)
        if device is not None:
            batch = data_to_device(batch, device)
        else:
            batch = fast_batch_copy(batch)
        # update state
        shm_state[read_count_index] += 1
        return batch

    def _read_batch(self, shm_data, data_index):
        # 1. get each dtype tensor, split using split_info.
        tensors = {}
        for dtype in shm_data:
            tensor = shm_data[dtype][1][data_index]  # get corresponding tensor from tensor_list
            split_info = self.batch_info[1][dtype].split_info
            splitted_tensor = tensor.split(split_info)
            tensors[dtype] = splitted_tensor
        # 2. traverse batch_info, using the split tensors (after reshape) to build a batch
        batch = _get_batch_with_tensor(self.batch_info[0], tensors)
        return batch

    def _get_data_index(self, shm_state, finish_index, write_count_index, read_count_index):
        # Wait until the data is ready and return its data index.
        # If there is no more data, return None.
        start_time = time.time()
        while True:
            # if shm_state[read_count_index] < 0, wait for master reset
            read_count = shm_state[read_count_index]
            if read_count >= 0:
                if shm_state[write_count_index] > read_count:
                    break
                if shm_state[finish_index] == 1:
                    return None
            if time.time() - start_time > self.io_timeout:
                raise TimeoutError("Timeout in read batch from shared memory")
            time.sleep(0.001)

        index = read_count % self.shm_data_size
        if self.coworker_case:
            return index * self.num_batch_per_step + self.rank
        else:
            return index

    def _get_state_index(self):
        # called by non-master only
        if self.coworker_case:
            finish_index = self.rank
            write_count_index = self.num_batch_per_step + self.rank
            read_count_index = 2 * self.num_batch_per_step + self.rank
            stop_index = 3 * self.num_batch_per_step + self.rank
        else:
            finish_index = 0
            write_count_index = 1
            read_count_index = 1 + self.rank
            stop_index = read_count_index + self.num_read_per_batch
        return finish_index, write_count_index, read_count_index, stop_index

    def wait_worker(self, wait_timeout=100):
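        # Master only: block until every worker has read all batches written so far,
        # or give up with a warning after wait_timeout seconds.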
        start_time = time.time()
        state_shm = self.shms[self.shm_group_indices[0]]["state"][1]
        timeout = False
        for idx in range(self.num_batch_per_step):
            current_write_count = state_shm[self.num_batch_per_step + idx]
            read_count_size = self.num_read_per_batch
            read_count_index = self.num_batch_per_step * 2 + idx * self.num_read_per_batch
            while True:
                ready = True
                for i in range(read_count_size):
                    if state_shm[read_count_index + i] < current_write_count:
                        ready = False
                        break
                if ready:
                    break
                if time.time() - start_time > wait_timeout:
                    logger.warning(f"timeout in shm_context master wait for worker {idx}")
                    timeout = True
                    break
                time.sleep(1)
            if timeout:
                break

    def reset(self):
        """
        For non-master, set READ_COUNT to -1.
        For master, wait until all READ_COUNT < 0, then reset.
        """
        if self.is_master:
            state_shm = self.shms[self.shm_group_indices[0]]["state"][1]
            for idx in range(self.num_batch_per_step):
                finish_index = idx
                write_count_index = self.num_batch_per_step + idx
                read_count_size = self.num_read_per_batch
                read_count_index = self.num_batch_per_step * 2 + idx * self.num_read_per_batch
                stop_index = self.num_batch_per_step * (2 + self.num_read_per_batch) + idx * self.num_read_per_batch
                # wait all READ_COUNT < 0 first
                while True:
                    ready = True
                    for i in range(read_count_size):
                        if state_shm[read_count_index + i] >= 0:
                            ready = False
                            break
                    if ready:
                        break
                    time.sleep(1)
                # reset then
                state_shm[finish_index] = 0
                state_shm[write_count_index] = 0
                for i in range(read_count_size):
                    state_shm[read_count_index + i] = 0
                    state_shm[stop_index + i] = 0
        else:
            for shm_group in self.shms.values():
                _, _, read_count_index, _ = self._get_state_index()
                shm_group["state"][1][read_count_index] = -1

    def tear_down(self, master_wait_for_worker=False, wait_timeout=60):
        # if master_wait_for_worker, master will tear down only after workers have consumed all data.
        # close shm, try unlink if master
        if not self.shm_created:
            return
        stop_status = self.get_stop_status()
        if self.is_master and master_wait_for_worker and not any(stop_status):
            self.wait_worker(wait_timeout=wait_timeout)
        for shm_group in self.shms.values():
            for shm, _ in shm_group["data"].values():
                _try_shm_close(shm, unlink=self.is_master)
            _try_shm_close(shm_group["state"][0], unlink=self.is_master)
        self.shm_created = False


def create_coworker_shm_context(
    sample_batch=None,
    dataset=None,
    dataloader_args=None,
    rank=None,
    group_size=None,
    shm_data_size=100,
    io_timeout=30,
    initialize_timeout=300,
    shm_name_prefix=None,
    need_sync_write=True,
):
    """
    Provide either:
    a. sample_batch
    Or:
    b. dataset + dataloader_args
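
    Example (sketch; assumes atorch.distributed is initialized and the names below are placeholders):
        shm_context = create_coworker_shm_context(dataset=train_dataset,
                                                   dataloader_args={"batch_size": 16},
                                                   rank=rank,
                                                   group_size=group_size)
        # Coworkers (or rank 0 in the non-coworker case) call shm_context.add_batch(batches);
        # workers call shm_context.get_data(); call shm_context.tear_down() when done.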
    """
    if sys.version_info.major < 3 or (sys.version_info.major == 3 and sys.version_info.minor < 8):
        logger.warning("Requires python version >= 3.8 to use shm_context.")
        return None
    if sample_batch is None:
        sample_batch = get_sample_batch(dataset, dataloader_args)
    shm_context = None
    if atorch.distributed.use_coworker():
        if atorch.distributed.is_coworker():
            local_rank = atorch.distributed.rank() % atorch.distributed.coworker_num_per_node()
        else:
            local_rank = atorch.distributed.rank() % atorch.distributed.worker_num_per_node()
        shm_context = ShmDataContext(
            is_master=atorch.distributed.is_coworker(),
            sample_batch=sample_batch,
            rank=local_rank,
            num_read_per_batch=1,
            num_batch_per_step=atorch.distributed.worker_num_per_node(),
            num_coworker_per_node=atorch.distributed.coworker_num_per_node(),
            shm_name_prefix=shm_name_prefix,
            shm_data_size=shm_data_size,
            io_timeout=io_timeout,
            initialize_timeout=initialize_timeout,
            need_sync_write=need_sync_write,
        )
    else:
        shm_context = ShmDataContext(
            is_master=rank == 0,
            sample_batch=sample_batch,
            rank=rank,
            num_read_per_batch=group_size - 1,
            num_batch_per_step=1,
            shm_name_prefix=shm_name_prefix,
            shm_data_size=shm_data_size,
            io_timeout=io_timeout,
            initialize_timeout=initialize_timeout,
            need_sync_write=need_sync_write,
        )
    return shm_context


class ShmData:
    def __init__(
        self,
        name,
        size,
        dtype=np.int64,
        per_rank_data=True,
        local_rank=None,
        local_rank_for_creation=0,
        nproc_per_node=None,
        initialize_timeout=300,
    ):
        """
        name(str): a unique name for ShmData.
        size(int): the number of elements in the data.
        dtype(np dtype): data type, using numpy types such as np.int64, np.float32.
        per_rank_data(bool): if True, each rank has its own data partition of `size` elements to write to;
                             if False, all ranks share the same `size`-element data.
        local_rank: if not provided, use atorch.distributed.local_rank().
        local_rank_for_creation: the local_rank responsible for shm creation. Default 0.
        nproc_per_node: if not provided, use atorch.distributed.nproc_per_node().
        initialize_timeout: timeout in seconds for shm initialization.
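
        Example (illustrative; assumes atorch.distributed is initialized so the defaults resolve):
            shm_data = ShmData("progress", size=4, dtype=np.int64, per_rank_data=True)
            shm_data.put(0, 100)                   # write into this rank's partition
            value = shm_data.get(0, local_rank=0)  # read index 0 of rank 0's partition
            shm_data.tear_down()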
        """
        if local_rank is None:
            local_rank = atorch.distributed.local_rank()
        if nproc_per_node is None:
            nproc_per_node = atorch.distributed.nproc_per_node()
        self.local_rank = local_rank
        self.nproc_per_node = nproc_per_node
        self.name = "local_shm_" + name
        self.dtype = dtype
        self.per_rank_data = per_rank_data
        self.size = size
        self.total_size = size * nproc_per_node if per_rank_data else size
        self.is_creator = local_rank == local_rank_for_creation

        if self.is_creator:
            byte_per_ele = np.dtype(dtype).itemsize
            shm = shared_memory.SharedMemory(create=True, size=self.total_size * byte_per_ele, name=self.name)
            np_data = np.ndarray([self.total_size], dtype=dtype, buffer=shm.buf)
            np_data.fill(0)
        else:
            shm = _try_get_shm(self.name, timeout=initialize_timeout)
            if shm is None:
                raise TimeoutError(f"cannot connect to shared memory {name}")
            np_data = np.ndarray([self.total_size], dtype=dtype, buffer=shm.buf)

        self.shm = shm
        self.np_data = np_data

    def _check_index(self, index):
        if isinstance(index, list):
            if index[0] < 0 or index[0] >= self.size or index[1] > self.size:
                raise ValueError("index out of range")
        else:
            if index < 0 or index >= self.size:
                raise ValueError("index out of range")

    def _get_offset(self, local_rank=None):
        if self.per_rank_data:
            offset = self.size * self.local_rank if local_rank is None else self.size * local_rank
        else:
            offset = 0
        return offset

    def put(self, index, data):
        """
        put data at index. If per_rank_data is True, put into this local_rank's data partition.
        index can be an int, or a range represented as a 2-item list.
        For example:
        put(2, 10) #  shm[2] = 10
        put([5, 7], [20, 30]) # shm[5] = 20, shm[6] = 30
        """
        self._check_index(index)
        offset = self._get_offset()
        if isinstance(index, list):
            for d_idx, idx in enumerate(range(index[0], index[1])):
                self.np_data[offset + idx] = data[d_idx]
        else:
            self.np_data[offset + index] = data

    def get(self, index, local_rank=None):
        """
        get data at index. index can be an int, or a range represented as a 2-item list.
        If per_rank_data is True, get data from the given local_rank's data partition.
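        For example:
        get(2)  # returns shm[2] (with this rank's partition offset if per_rank_data)
        get([5, 7])  # returns [shm[5], shm[6]] (with the partition offset applied)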
        """
        self._check_index(index)
        offset = self._get_offset(local_rank)
        if isinstance(index, list):
            res = []
            for idx in range(index[0], index[1]):
                res.append(self.np_data[offset + idx])
            return res
        else:
            return self.np_data[offset + index]

    def tear_down(self):
        _try_shm_close(self.shm, unlink=self.is_creator)
